from __future__ import annotations

import argparse
import csv
from collections.abc import Iterable
from pathlib import Path

# Benchmark task name -> CSV file name looked up inside each checkpoint's
# result directory. Insertion order determines the column order of the
# summary CSV written by main().
TASK_CSV = {
    "math_0shot_cot": "math_0shot_cot.csv",
    "math_4shot_cot": "math_4shot_cot.csv",
    "gsm8k_0shot_cot": "gsm8k_0shot_cot.csv",
    "triviaqa_4shot": "triviaqa_4shot.csv",
    "hellaswag_4shot": "hellaswag_4shot.csv",
    "gsm8k_4shot": "gsm8k_4shot.csv",
    "gsm_plus_5shot_cot": "gsm_plus_5shot_cot.csv",
    "gsm8k_train": "train.csv",
    "gsm8k_test": "test.csv",
}

# Sub-directories of a checkpoint result directory that are searched, in this
# order, for the files listed in TASK_CSV (first match wins).
SUBDIR_CANDIDATES = [
    "large_language_monkeys",
    "lm-evaluation-harness",
    "openai/gsm8k",
]


def read_avg(path: Path) -> tuple[float, float]:
    """
    Return the ``(bits_tok, bits_byte)`` pair stored on the last line of *path*.

    The last non-empty line must have the form ``AVG,<bits_tok>,<bits_byte>``.

    Returns:
        The two parsed floats, or ``(-1.0, -1.0)`` when the file is missing or
        unreadable, is empty, or its last line does not match that form.
    """
    missing = (-1.0, -1.0)
    try:
        lines = path.read_text(encoding="utf-8").strip().splitlines()
    except (OSError, UnicodeDecodeError):
        # Missing file, directory, permission problem, or undecodable bytes.
        return missing
    if not lines:
        return missing

    fields = lines[-1].split(",")
    # Only accept the summary line; any other 3-column row is not an average.
    if len(fields) != 3 or fields[0].strip() != "AVG":
        return missing
    try:
        return float(fields[1]), float(fields[2])
    except ValueError:
        return missing


def try_find_csv(
    ckpt: Path,
    fname: str,
    candidates: Iterable[str] | None = None,
) -> Path | None:
    """
    Locate *fname* under the checkpoint directory *ckpt*.

    Each candidate sub-directory is tried in order (``SUBDIR_CANDIDATES`` by
    default) and the first existing regular file wins. ``ckpt / fname`` itself
    is checked last, so a file in a known sub-directory keeps precedence over
    one placed directly in the checkpoint directory.

    Args:
        ckpt: checkpoint result directory.
        fname: file name to look for.
        candidates: sub-directory names to search; ``None`` means
            ``SUBDIR_CANDIDATES``.

    Returns:
        The first matching path, or ``None`` when the file is not found.
    """
    subdirs = SUBDIR_CANDIDATES if candidates is None else candidates
    for sub in subdirs:
        p = ckpt / sub / fname
        # is_file(): a directory with the same name is not a usable result CSV.
        if p.is_file():
            return p
    direct = ckpt / fname
    return direct if direct.is_file() else None


def _fmt_bits(value: float) -> str:
    """Format a bits value to 4 decimals, or return "-1" for a missing (negative) value."""
    return f"{value:.4f}" if value >= 0 else "-1"


def main() -> None:
    """
    Build a per-checkpoint summary CSV of the AVG bits values of every task.

    Reads checkpoint paths from ``--list`` (one per line, blank lines ignored)
    and re-roots each one under ``results/``. For every task in TASK_CSV, it
    looks up the task's CSV and reads its ``AVG`` line. The output has one row
    per checkpoint: the checkpoint path, followed by ``<task>_bits_tok`` /
    ``<task>_bits_byte`` columns. Values that could not be read are written
    as ``-1``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--list",
        type=Path,
        required=True,
        help="モデル ckpt パスを 1 行ずつ列挙した txt",
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=Path("aggregate_scores.csv"),
        help="集計結果を書き出す CSV",
    )
    args = ap.parse_args()

    # The leading "/" is stripped so absolute paths in the list land *inside*
    # results/ (joining an absolute path would otherwise discard "results").
    ckpt_paths = [
        Path("results") / line.strip().lstrip("/")
        for line in args.list.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]

    # Header row: one bits_tok/bits_byte column pair per task, in TASK_CSV order.
    header = ["ckpt_path"]
    for task in TASK_CSV:
        header += [f"{task}_bits_tok", f"{task}_bits_byte"]

    rows: list[list[str]] = [header]
    for ckpt in ckpt_paths:
        row: list[str] = [str(ckpt)]
        for fname in TASK_CSV.values():
            fpath = try_find_csv(ckpt, fname)
            b_tok, b_byte = read_avg(fpath) if fpath else (-1.0, -1.0)
            row += [_fmt_bits(b_tok), _fmt_bits(b_byte)]
        rows.append(row)

    # Write the summary; newline="" is required by the csv module.
    args.out.parent.mkdir(parents=True, exist_ok=True)
    with args.out.open("w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)
    print(f"✓ summary CSV written to {args.out.resolve()}")


if __name__ == "__main__":
    main()
